package com.yeskery.nut.http;

import com.yeskery.nut.core.*;
import com.yeskery.nut.util.StringUtils;

import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.stream.Collectors;

/**
 * {@link Request} implementation for requests submitted as binary form data
 * ({@code multipart/form-data}). It wraps a {@link BaseRequest}, parses the multipart body once
 * at construction time and delegates every other call to the wrapped request.
 *
 * <p>Text fields are stored as raw {@code byte[]} values and file fields as
 * {@link BasicMultipartFile} instances; both are kept per field name in the order in which
 * they appear in the body.
 *
 * @author sprout
 * 2019-03-15 11:58
 * @version 1.0
 *
 * @see com.yeskery.nut.core.Request
 * @see BaseRequest
 */
public class FormDataRequest extends AbstractWrapperRequest implements Serializable, Cloneable {

	/** Separator between a parameter name and its value (for example {@code boundary=xyz}). */
	private static final String SEPARATE = "=";

	/** Bytes expected after the last boundary of the body: the closing {@code --} and a line break. */
	private static final String FORM_DATA_END_LINE = "--\r\n";

	/** Name of the Content-Type parameter that carries the multipart boundary. */
	private static final String BOUNDARY_PARAMETER = "boundary";

	/** Parsed form fields: each value is either a {@code byte[]} (text field) or a {@link BasicMultipartFile}. */
	private final Map<String, List<Object>> formParameters = new HashMap<>();

	/**
	 * Creates a {@link FormDataRequest} and immediately parses the multipart body.
	 *
	 * <p>Side effect: the {@code Content-Type} entry of the header map returned by the wrapped
	 * request is replaced by the bare media type (the boundary parameter is dropped).
	 *
	 * @param baseRequest the underlying {@link BaseRequest}
	 * @throws NutException if the Content-Type header is missing or carries no usable boundary
	 */
	public FormDataRequest(BaseRequest baseRequest) {
		super(baseRequest);
		initialize();
	}

	/**
	 * Splits the request body at every boundary delimiter and hands each part to
	 * {@link #initializeForm(byte[])}.
	 *
	 * @throws NutException if the Content-Type header is missing or carries no usable boundary
	 */
	private void initialize() {
		String contentType = getHeader(HttpHeader.CONTENT_TYPE);
		if (StringUtils.isEmpty(contentType)) {
			throw new NutException("Content Type Not Match multipart/form-data");
		}
		String[] contentTypeValues = contentType.split(";");
		String boundary = resolveBoundary(contentTypeValues);
		if (boundary == null) {
			throw new NutException("Content Type Not Match multipart/form-data");
		}
		getHeaders().put(HttpHeader.CONTENT_TYPE, Collections.singletonList(contentTypeValues[0].trim()));
		byte[] bodyBytes = getBody();
		if (bodyBytes == null) {
			// No body at all: nothing to parse, the request simply has no form parameters.
			return;
		}
		// Explicit charset so the delimiter bytes do not depend on the platform default.
		byte[] boundaryBytes = ("--" + boundary).getBytes(StandardCharsets.UTF_8);

		// Position of the most recently seen delimiter; -1 until the first one is found so that
		// anything in front of the first delimiter is ignored instead of being parsed as a part.
		int index = -1;
		// Scan every position at which the delimiter still fits, so a closing delimiter that is
		// not followed by a line break is recognised as well.
		for (int i = 0; i <= bodyBytes.length - boundaryBytes.length; i++) {
			if (!validBoundary(bodyBytes, i, boundaryBytes)) {
				continue;
			}
			if (index >= 0) {
				// The part starts after the previous delimiter and its line break. The length
				// deliberately leaves one trailing byte of the line break that precedes the
				// current delimiter (presumably '\r' - BOX_OFF is assumed to be CRLF, confirm
				// against BaseRequest); initializeForm strips that byte again.
				int formLength = i - index - boundaryBytes.length - 3;
				// Adjacent or overlapping delimiters give a non-positive length: skip them
				// instead of failing with NegativeArraySizeException.
				if (formLength > 0) {
					byte[] formBytes = new byte[formLength];
					System.arraycopy(bodyBytes, index + boundaryBytes.length + BaseRequest.BOX_OFF.length
							, formBytes, 0, formBytes.length);
					initializeForm(formBytes);
				}
			}
			index = i;
		}
		if (index < 0) {
			// The body contains no delimiter at all.
			return;
		}

		// Whatever follows the last delimiter: either the closing marker or an unterminated part.
		int remaining = bodyBytes.length - index - boundaryBytes.length;
		if (remaining <= 1) {
			return;
		}
		if (remaining == 4) {
			String tail = new String(bodyBytes, bodyBytes.length - 4, 4, StandardCharsets.UTF_8);
			if (FORM_DATA_END_LINE.equals(tail)) {
				return;
			}
		}
		byte[] formBytes = new byte[remaining - BaseRequest.BOX_OFF.length];
		System.arraycopy(bodyBytes, index + boundaryBytes.length + BaseRequest.BOX_OFF.length
				, formBytes, 0, formBytes.length);
		initializeForm(formBytes);
	}

	/**
	 * Extracts the boundary from the {@code ;}-separated segments of the Content-Type header.
	 * A parameter named {@code boundary} (case-insensitive) is preferred wherever it appears;
	 * if none exists, the value of the second segment is used, as older versions always did.
	 * Surrounding whitespace and one pair of surrounding double quotes are removed.
	 *
	 * @param contentTypeValues the Content-Type header split at {@code ;}
	 * @return the boundary without the leading {@code --}, or {@code null} if none can be determined
	 */
	private static String resolveBoundary(String[] contentTypeValues) {
		String candidate = null;
		for (int i = 1; i < contentTypeValues.length; i++) {
			String parameter = contentTypeValues[i];
			int separatorIndex = parameter.indexOf(SEPARATE);
			if (separatorIndex < 0) {
				continue;
			}
			if (BOUNDARY_PARAMETER.equalsIgnoreCase(parameter.substring(0, separatorIndex).trim())) {
				candidate = parameter.substring(separatorIndex + 1);
				break;
			}
		}
		if (candidate == null) {
			if (contentTypeValues.length <= 1 || !contentTypeValues[1].contains(SEPARATE)) {
				return null;
			}
			candidate = contentTypeValues[1].substring(contentTypeValues[1].indexOf(SEPARATE) + 1);
		}
		candidate = candidate.trim();
		if (candidate.length() >= 2 && candidate.charAt(0) == '"'
				&& candidate.charAt(candidate.length() - 1) == '"') {
			candidate = candidate.substring(1, candidate.length() - 1);
		}
		// An empty boundary would make the bare "--" match all over the body.
		return candidate.isEmpty() ? null : candidate;
	}

	/**
	 * Checks whether {@code boundaryBytes} occurs in {@code bytes} starting at {@code index}.
	 * The caller must guarantee that {@code index + boundaryBytes.length <= bytes.length}.
	 *
	 * @param bytes the data to inspect
	 * @param index the start offset inside {@code bytes}
	 * @param boundaryBytes the delimiter to look for
	 * @return {@code true} if every delimiter byte matches
	 */
	private boolean validBoundary(byte[] bytes, int index, byte[] boundaryBytes) {
		for (int i = 0; i < boundaryBytes.length; i++) {
			if (bytes[index + i] != boundaryBytes[i]) {
				return false;
			}
		}
		return true;
	}

	/**
	 * Parses one multipart part (header lines, an empty line, then the value) and stores it in
	 * {@link #formParameters}.
	 *
	 * <p>A part whose header block is a single line is stored as a text field ({@code byte[]});
	 * a part with more header lines is stored as a {@link BasicMultipartFile}. Parts that cannot
	 * be interpreted (no header block, missing or unquoted {@code name}/{@code filename}, second
	 * header line without {@code :}) are skipped silently.
	 *
	 * <p>Known limitation: the first header line is split at {@code ;}, so a name or file name
	 * containing {@code ;} is not handled.
	 *
	 * @param bytes the raw bytes of the part, starting right after the delimiter line
	 */
	private void initializeForm(byte[] bytes) {
		// Find the empty line (two consecutive line breaks) that ends the header block. The
		// bound keeps the i + 3 access inside the array.
		int index = 0;
		for (int i = 0; i + 3 < bytes.length; i++) {
			if (bytes[i] == BaseRequest.BOX_OFF[0] && bytes[i + 1] == BaseRequest.BOX_OFF[1]
					&& bytes[i + 2] == BaseRequest.BOX_OFF[0] && bytes[i + 3] == BaseRequest.BOX_OFF[1]) {
				index = i;
				break;
			}
		}
		if (index == 0) {
			// No header block (or an empty one): there is no field name to store the value under.
			return;
		}
		byte[] formKeys = Arrays.copyOfRange(bytes, 0, index);
		int byteLength = bytes.length - index - (BaseRequest.BOX_OFF.length * 2);
		if (bytes[bytes.length - 1] == BaseRequest.BOX_OFF[0]) {
			// Drop the single trailing byte that initialize() leaves at the end of a part.
			byteLength -= 1;
		}
		byte[] formValues = new byte[byteLength];
		System.arraycopy(bytes, index + (BaseRequest.BOX_OFF.length * 2), formValues, 0, formValues.length);

		// Find the first line break inside the header block; 0 means "single header line".
		index = 0;
		for (int i = 0; i < formKeys.length - BaseRequest.BOX_OFF.length; i++) {
			if (formKeys[i] == BaseRequest.BOX_OFF[0] && formKeys[i + 1] == BaseRequest.BOX_OFF[1]) {
				index = i;
				break;
			}
		}
		if (index == 0) {
			// Text field: <header>; name="key"
			String[] tempFormKeys = new String(formKeys, StandardCharsets.UTF_8).split(";");
			if (tempFormKeys.length < 2) {
				return;
			}
			String key = firstQuotedValue(tempFormKeys[1]);
			if (key == null) {
				return;
			}
			formParameters.computeIfAbsent(key, k -> new ArrayList<>()).add(formValues);
		} else {
			// File field: <header>; name="key"; filename="file" on the first line, followed by
			// a second header line of the form "<name>: <content type>".
			byte[] tempKey1 = Arrays.copyOfRange(formKeys, 0, index);
			byte[] tempKey2 = Arrays.copyOfRange(formKeys, index + BaseRequest.BOX_OFF.length, formKeys.length);
			String[] tempFormKeys = new String(tempKey1, StandardCharsets.UTF_8).split(";");
			if (tempFormKeys.length < 3) {
				return;
			}
			String key = firstQuotedValue(tempFormKeys[1]);
			String fileName = firstQuotedValue(tempFormKeys[2]);
			String[] contentTypeParts = new String(tempKey2, StandardCharsets.UTF_8).split(":");
			if (key == null || fileName == null || contentTypeParts.length < 2) {
				return;
			}
			// Trim the blank that usually follows the colon.
			String contentType = contentTypeParts[1].trim();
			formParameters.computeIfAbsent(key, k -> new ArrayList<>())
					.add(new BasicMultipartFile(fileName, contentType, formValues));
		}
	}

	/**
	 * Returns the text between the first pair of double quotes in {@code text}.
	 *
	 * @param text the text to search
	 * @return the quoted content, or {@code null} if {@code text} has no complete pair of quotes
	 */
	private static String firstQuotedValue(String text) {
		int start = text.indexOf('"');
		if (start < 0) {
			return null;
		}
		int end = text.indexOf('"', start + 1);
		if (end < 0) {
			return null;
		}
		return text.substring(start + 1, end);
	}

	/** Delegates to the wrapped request. */
	@Override
	public List<String> getHeaders(String key) {
		return getBaseRequest().getHeaders(key);
	}

	/** Delegates to the wrapped request. */
	@Override
	public String getHeader(String key) {
		return getBaseRequest().getHeader(key);
	}

	/** Delegates to the wrapped request. */
	@Override
	public Map<String, List<String>> getHeaders() {
		return getBaseRequest().getHeaders();
	}

	/**
	 * Returns the parameter names of the wrapped request together with the form field names.
	 *
	 * @return a new set containing both groups of names
	 */
	@Override
	public Set<String> getParameterKeys() {
		Set<String> keySet = new HashSet<>(getBaseRequest().getParameterKeys());
		keySet.addAll(formParameters.keySet());
		return keySet;
	}

	/** Delegates to the wrapped request. */
	@Override
	public Method getMethod() {
		return getBaseRequest().getMethod();
	}

	/** Delegates to the wrapped request. */
	@Override
	public String getOriginalPath() {
		return getBaseRequest().getOriginalPath();
	}

	/** Delegates to the wrapped request. */
	@Override
	public String getPath() {
		return getBaseRequest().getPath();
	}

	/** Delegates to the wrapped request. */
	@Override
	public String getProtocol() {
		return getBaseRequest().getProtocol();
	}

	/** Delegates to the wrapped request. */
	@Override
	public String getRemoteAddress() {
		return getBaseRequest().getRemoteAddress();
	}

	/** Delegates to the wrapped request. */
	@Override
	public String getRemoteHost() {
		return getBaseRequest().getRemoteHost();
	}

	/** Delegates to the wrapped request. */
	@Override
	public int getRemotePort() {
		return getBaseRequest().getRemotePort();
	}

	/** Delegates to the wrapped request. */
	@Override
	public String getLocalAddress() {
		return getBaseRequest().getLocalAddress();
	}

	/** Delegates to the wrapped request. */
	@Override
	public String getLocalHost() {
		return getBaseRequest().getLocalHost();
	}

	/** Delegates to the wrapped request. */
	@Override
	public int getLocalPort() {
		return getBaseRequest().getLocalPort();
	}

	/** Delegates to the wrapped request. */
	@Override
	public ServerRequestConfiguration getServerRequestConfiguration() {
		return getBaseRequest().getServerRequestConfiguration();
	}

	/** Delegates to the wrapped request. */
	@Override
	public Cookie[] getCookies() {
		return getBaseRequest().getCookies();
	}

	/** Delegates to the wrapped request. */
	@Override
	public boolean hasCookie(String name) {
		return getBaseRequest().hasCookie(name);
	}

	/**
	 * Returns the files uploaded under the given field name.
	 *
	 * @param key the form field name
	 * @return the uploaded files in body order (text values sent under the same name are left
	 *         out, so the list never contains {@code null} elements), or {@code null} if the
	 *         form has no field with that name
	 */
	@Override
	public List<MultipartFile> getFiles(String key) {
		List<Object> values = formParameters.get(key);
		if (values == null) {
			return null;
		}
		return values.stream()
				.filter(MultipartFile.class::isInstance)
				.map(MultipartFile.class::cast)
				.collect(Collectors.toList());
	}

	/** Delegates to the wrapped request. */
	@Override
	public byte[] getBody() {
		return getBaseRequest().getBody();
	}

	/**
	 * Returns the values of the wrapped request for {@code key} followed by the form values.
	 * Text fields are decoded as UTF-8; any other stored value (a file) contributes its
	 * {@code toString()} representation.
	 *
	 * @param key the parameter name
	 * @return a new list owned by the caller; the list of the wrapped request is never modified
	 */
	@Override
	public List<String> getParameters(String key) {
		List<String> baseValues = getBaseRequest().getParameters(key);
		// Copy first: appending to the list handed out by the wrapped request would change its
		// state and make the form values pile up on every call.
		List<String> values = baseValues == null ? new ArrayList<>() : new ArrayList<>(baseValues);
		List<Object> objects = formParameters.get(key);
		if (objects != null) {
			for (Object value : objects) {
				values.add(value instanceof byte[]
						? new String((byte[]) value, StandardCharsets.UTF_8)
						: value.toString());
			}
		}
		return values;
	}

	/**
	 * Returns the parameters of the wrapped request merged with the form fields.
	 *
	 * @return a new map; for every form field name the value is the result of
	 *         {@link #getParameters(String)}, all other entries are taken from the wrapped request
	 */
	@Override
	public Map<String, List<String>> getParametersMap() {
		Map<String, List<String>> baseMap = getBaseRequest().getParametersMap();
		// Work on a copy so the map of the wrapped request stays untouched.
		Map<String, List<String>> resultMap = baseMap == null ? new LinkedHashMap<>() : new LinkedHashMap<>(baseMap);
		for (String key : formParameters.keySet()) {
			// getParameters already contains the wrapped request's values, so replace rather
			// than append to avoid listing them twice.
			resultMap.put(key, getParameters(key));
		}
		return resultMap;
	}

	/** Delegates to the wrapped request. */
	@Override
	public Map<String, List<String>> getQueryParametersMap() {
		return getBaseRequest().getQueryParametersMap();
	}

	/** Delegates to the wrapped request. */
	@Override
	public void addAttribute(String name, Object value) {
		getBaseRequest().addAttribute(name, value);
	}

	/** Delegates to the wrapped request. */
	@Override
	public void removeAttribute(String name) {
		getBaseRequest().removeAttribute(name);
	}

	/** Delegates to the wrapped request. */
	@Override
	public Object getAttribute(String name) {
		return getBaseRequest().getAttribute(name);
	}

	/** Delegates to the wrapped request. */
	@Override
	public Map<String, Object> getAttributes() {
		return getBaseRequest().getAttributes();
	}

	/** Delegates to the wrapped request. */
	@Override
	public Session getSession() {
		return getBaseRequest().getSession();
	}

	/** Delegates to the wrapped request. */
	@Override
	public ServerContext getServerContext() {
		return getBaseRequest().getServerContext();
	}

	/** Delegates to the wrapped request. */
	@Override
	public boolean isEmpty() {
		return getBaseRequest().isEmpty();
	}

	/**
	 * Returns all parsed form fields.
	 *
	 * @return the internal map itself (not a copy); each value is a {@code byte[]} for a text
	 *         field or a {@link BasicMultipartFile} for a file field
	 */
	public Map<String, List<Object>> getFormParametersMap() {
		return formParameters;
	}

	/**
	 * Returns the first value of every parsed form field.
	 *
	 * @return a new map from field name to its first value, or to {@code ""} if the field has
	 *         no values
	 */
	public Map<String, Object> getFormParameterMap() {
		return getFormParametersMap()
				.entrySet()
				.stream()
				.collect(Collectors.toMap(Map.Entry::getKey, e -> {
					List<Object> values = e.getValue();
					return values != null && !values.isEmpty() ? values.get(0) : "";
				}));
	}

	/**
	 * Creates a shallow copy through {@code super.clone()}; the copy shares the parsed form
	 * parameter map with this instance.
	 *
	 * @return the cloned request
	 * @throws NutException if cloning is not supported
	 */
	@Override
	public FormDataRequest clone() {
		try {
			return (FormDataRequest) super.clone();
		} catch (CloneNotSupportedException e) {
			throw new NutException(e);
		}
	}
}
